// Copyright 2022 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_DNN_ACCELERATOR_HLS_SYSTEMC_PROCESSINGELEMENT_H_
#define THIRD_PARTY_DNN_ACCELERATOR_HLS_SYSTEMC_PROCESSINGELEMENT_H_

#include <ac_std_float.h>
#include <mc_connections.h>
#include <systemc.h>

#include "src/AccelTypes.h"

// A multiply-accumulate cell with a weight pass-through path and streaming
// input / partial-sum channels.
// NOTE(review): presumably one cell of a larger PE array - confirm against
// the instantiating module.
//
// Template parameters:
//   IDTYPE - type carried on the input channels.
//   WDTYPE - type of the weight ports and of the stored weight.
//   ODTYPE - type carried on the partial-sum channels.
//
// Two clocked threads run on the rising edge of `clk` with an asynchronous,
// active-low reset `rstn`:
//   * processWeights() captures `weightIn` while `weightValid` is high and
//     forwards the captured value on `weightOut`.
//   * run() pops one input, one partial sum and one swap flag per iteration,
//     optionally replaces the active weight with the most recently captured
//     one, computes input * weight + psum and pushes the results onward.
template <typename IDTYPE, typename WDTYPE, typename ODTYPE>
SC_MODULE(ProcessingElement) {
 private:
  // Most recently captured weight, written by processWeights() and read by
  // run(); a signal is used because it crosses between the two threads.
  sc_signal<WDTYPE> updatedWeight;
  // Weight currently used by the multiply-accumulate in run().
  WDTYPE weightReg;
  // Last weight captured from weightIn; also the value driven on weightOut.
  WDTYPE weightFifo;

 public:
  sc_in<bool> CCS_INIT_S1(clk);
  sc_in<bool> CCS_INIT_S1(rstn);

  // Weight pass-through path (plain signals, qualified by weightValid).
  sc_in<WDTYPE> CCS_INIT_S1(weightIn);
  sc_in<bool> CCS_INIT_S1(weightValid);
  sc_out<WDTYPE> CCS_INIT_S1(weightOut);

  // One-bit flag travelling with the data; when set, run() adopts the
  // captured weight before computing.
  Connections::In<ac_int<1, false> > CCS_INIT_S1(weightSwapIn);
  Connections::Out<ac_int<1, false> > CCS_INIT_S1(weightSwapOut);

  Connections::In<IDTYPE> CCS_INIT_S1(inputIn);
  Connections::In<ODTYPE> CCS_INIT_S1(psumIn);

  Connections::Out<IDTYPE> CCS_INIT_S1(inputOut);
  Connections::Out<ODTYPE> CCS_INIT_S1(psumOut);

  SC_CTOR(ProcessingElement) {
    SC_THREAD(processWeights);
    sensitive << clk.pos();
    async_reset_signal_is(rstn, false);

    SC_THREAD(run);
    sensitive << clk.pos();
    async_reset_signal_is(rstn, false);
  }

  // Weight capture thread.
  //
  // Every clock cycle: if weightValid is high, store weightIn in weightFifo
  // and publish it on updatedWeight; then drive weightOut with weightFifo.
  void processWeights() {
    // Reset section. weightFifo and updatedWeight are given a reset value so
    // that the first cycles after reset (before any weightValid pulse) do not
    // read or forward a never-assigned value.
    // NOTE(review): WDTYPE() is the same reset value the output port already
    // used; whether it is all-zero depends on WDTYPE's default constructor -
    // confirm for the ac_* types in use.
    weightFifo = WDTYPE();
    updatedWeight.write(WDTYPE());
    weightOut.write(WDTYPE());

    wait();

    while (true) {
      if (weightValid.read()) {
        weightFifo = weightIn.read();
        updatedWeight = weightFifo;
      }
      // Forwarded unconditionally, so weightOut holds the last captured
      // weight while weightValid is low.
      weightOut.write(weightFifo);

      wait();
    }
  }

  // Multiply-accumulate thread.
  //
  // Each iteration pops (input, psum, weightSwap), swaps in the captured
  // weight when weightSwap is set, and pushes (input, input*weight+psum,
  // weightSwap) on the corresponding output channels.
  void run() {
    inputIn.Reset();
    psumIn.Reset();
    inputOut.Reset();
    psumOut.Reset();
    weightSwapIn.Reset();
    weightSwapOut.Reset();
    // Give the active weight a reset value: it is otherwise read by fma()
    // before any assignment whenever the first popped weightSwap is 0.
    weightReg = WDTYPE();

    wait();

#pragma hls_pipeline_init_interval 1
#pragma hls_pipeline_stall_mode flush
    while (true) {
      IDTYPE input = inputIn.Pop();
      ODTYPE psum = psumIn.Pop();
      ac_int<1, false> weightSwap = weightSwapIn.Pop();

      // The swap happens before the multiply, so the new weight is already
      // used for the element that carries the flag.
      if (weightSwap) {
        weightReg = updatedWeight;
      }

      ODTYPE output = fma(input, weightReg, psum);

      // The input and the swap flag are forwarded unchanged.
      inputOut.Push(input);
      psumOut.Push(output);
      weightSwapOut.Push(weightSwap);
    }
  }

  // bfloat16 overload: delegates to the type's own fma with AC_RND_CONV
  // rounding. The second template argument is passed as false; see
  // ac_std_float.h for its meaning.
  // Returns input * weight + psum as computed by that member function.
  ac::bfloat16 fma(ac::bfloat16 input, ac::bfloat16 weight, ac::bfloat16 psum) {
    // CCS_LOG(input << "*" << weight << "+" << psum);
    return input.fma<AC_RND_CONV, false>(weight, psum);
  }

  // Generic overload for any single type DTYPE supporting * and +.
  // All three arguments must deduce to the same DTYPE.
  // Returns input * weight + psum converted to DTYPE.
  template <typename DTYPE>
  DTYPE fma(DTYPE input, DTYPE weight, DTYPE psum) {
    // CCS_LOG(input << "*" << weight << "+" << psum);
    return input * weight + psum;
  }
};

#endif  // THIRD_PARTY_DNN_ACCELERATOR_HLS_SYSTEMC_PROCESSINGELEMENT_H_
